"""Adapter utils."""

import torch
from pathlib import Path
from torch import nn, Tensor
from typing import Optional, Dict, Type, Callable, Any, List
from torch.optim import Optimizer
import transformers
from tqdm.autonotebook import trange
from llama_index.embeddings.adapter_utils import BaseAdapter
from llama_index.bridge.langchain import print_text

from sentence_transformers.util import cos_sim


class MyMultipleNegativesRankingLoss(nn.Module):
    """Multiple negatives ranking loss.

    This loss is similar to the one in sentence_transformers,
    but optimized for our own embeddings.

    """

    def __init__(
        self,
        model: BaseAdapter,
        scale: float = 20.0,
        similarity_fct: Optional[Callable] = None,
    ):
        """Define ranking loss."""

        super(MyMultipleNegativesRankingLoss, self).__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = cos_sim if similarity_fct is None else similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, query_embeds: Tensor, context_embeds: Tensor) -> Tensor:
        """Forward pass."""

        # transform context embeds
        # context_embeds_2 = self.model.forward(context_embeds)
        query_embeds_2 = self.model.forward(query_embeds)

        scores = self.similarity_fct(query_embeds_2, context_embeds) * self.scale
        labels = torch.tensor(
            range(len(scores)), dtype=torch.long, device=scores.device
        )
        return self.cross_entropy_loss(scores, labels)


def train_model(
    model: BaseAdapter,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    epochs: int = 1,
    steps_per_epoch: Optional[int] = None,
    warmup_steps: int = 10000,
    optimizer_class: Type[Optimizer] = torch.optim.AdamW,
    optimizer_params: Dict[str, Any] = {"lr": 2e-5},
    output_path: str = "model_output",
    max_grad_norm: float = 1,
    show_progress_bar: bool = True,
    verbose: bool = False,
    # callback: Callable[[float, int, int], None] = None,
    # scheduler: str = "WarmupLinear",
    # weight_decay: float = 0.01,
    # evaluation_steps: int = 0,
    # save_best_model: bool = True,
    # use_amp: bool = False,  # disable this option for now
    checkpoint_path: Optional[str] = None,
    checkpoint_save_steps: int = 500,
    # checkpoint_save_total_limit: int = 0,
) -> None:
    """Train model."""

    model.to(device)
    # TODO: hardcode loss now, make customizable later
    loss_model = MyMultipleNegativesRankingLoss(model=model)
    loss_model.to(device)

    # prepare optimizer/scheduler
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters: List[Dict[str, Any]] = [
        {
            "params": [p for n, p in param_optimizer],
        },
    ]
    optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
    if steps_per_epoch is None or steps_per_epoch == 0:
        steps_per_epoch = len(data_loader)
    num_train_steps = int(steps_per_epoch * epochs)
    scheduler_obj = transformers.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_steps
    )

    if verbose:
        print_text("> Prepared optimizer, scheduler, and loss model.\n", color="blue")

    global_step = 0
    data_iterator = iter(data_loader)

    # if checkpoint_path is specified, create if doesn't exist
    if checkpoint_path is not None:
        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

    for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
        training_steps = 0
        loss_model.zero_grad()
        loss_model.train()
        for _ in trange(
            steps_per_epoch,
            desc="Iteration",
            smoothing=0.05,
            disable=not show_progress_bar,
        ):
            try:
                data = next(data_iterator)
            except StopIteration:
                data_iterator = iter(data_loader)
                data = next(data_iterator)

            query, context = data
            context = context.to(device)
            query = query.to(device)

            loss_value = loss_model(query, context)
            if verbose:
                print_text(
                    f"> [Epoch {epoch}] Current loss: {loss_value}\n", color="blue"
                )
            loss_value.backward()
            torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
            optimizer.step()

            optimizer.zero_grad()

            scheduler_obj.step()

            training_steps += 1
            global_step += 1

            # TODO: skip eval for now
            if checkpoint_path is not None and global_step % checkpoint_save_steps == 0:
                full_ck_path = Path(checkpoint_path) / f"step_{global_step}"
                model.save(str(full_ck_path))

    if verbose:
        print_text(f"> Finished training, saving to {output_path}\n", color="blue")

    # save model
    model.save(output_path)
